{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-04-07 13:51:48--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 540689175 (516M) [application/zip]\n",
      "Saving to: ‘train_sample.zip’\n",
      "\n",
      "100%[======================================>] 540,689,175 10.6MB/s   in 53s    \n",
      "\n",
      "2021-04-07 13:52:42 (9.67 MB/s) - ‘train_sample.zip’ saved [540689175/540689175]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -qq train_sample.zip\n",
    "!\\rm train_sample.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-04-07 13:54:15--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip\n",
      "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
      "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1092637852 (1.0G) [application/zip]\n",
      "Saving to: ‘test_a.zip’\n",
      "\n",
      "100%[====================================>] 1,092,637,852 9.31MB/s   in 1m 50s \n",
      "\n",
      "2021-04-07 13:56:05 (9.45 MB/s) - ‘test_a.zip’ saved [1092637852/1092637852]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -qq test_a.zip\n",
    "!\\rm test_a.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 环境要求\n",
    "\n",
    "- TensorFlow的版本：2.0 + \n",
    "- keras\n",
    "- sklearn\n",
    "- librosa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 基本库\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载深度学习框架"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.\n",
      "  from numpy.core.umath_tests import inner1d\n"
     ]
    }
   ],
   "source": [
    "# 搭建分类模型所需要的库\n",
    "\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout\n",
    "from tensorflow.keras.utils import to_categorical \n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.svm import SVC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载音频处理库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
      "Collecting librosa\n",
      "  Using cached librosa-0.8.0-py3-none-any.whl\n",
      "Requirement already satisfied: decorator>=3.0.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (4.4.2)\n",
      "Collecting soundfile>=0.9.0\n",
      "  Using cached https://mirrors.aliyun.com/pypi/packages/eb/f2/3cbbbf3b96fb9fa91582c438b574cff3f45b29c772f94c400e2c99ef5db9/SoundFile-0.10.3.post1-py2.py3-none-any.whl (21 kB)\n",
      "Requirement already satisfied: numpy>=1.15.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.19.4)\n",
      "Collecting pooch>=1.0\n",
      "  Using cached https://mirrors.aliyun.com/pypi/packages/40/b9/9876662636ba451d4406543047c0b45ca5b4e830f931308c8274dad1db43/pooch-1.3.0-py3-none-any.whl (51 kB)\n",
      "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (0.24.0)\n",
      "Collecting audioread>=2.0.0\n",
      "  Using cached audioread-2.1.9-py3-none-any.whl\n",
      "Requirement already satisfied: numba>=0.43.0 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (0.53.1)\n",
      "Requirement already satisfied: joblib>=0.14 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.0.0)\n",
      "Collecting resampy>=0.2.2\n",
      "  Using cached resampy-0.2.2-py3-none-any.whl\n",
      "Requirement already satisfied: scipy>=1.0.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.5.4)\n",
      "Requirement already satisfied: llvmlite<0.37,>=0.36.0rc1 in /data/nas/workspace/envs/python3.6/site-packages (from numba>=0.43.0->librosa) (0.36.0)\n",
      "Requirement already satisfied: setuptools in /opt/conda/lib/python3.6/site-packages (from numba>=0.43.0->librosa) (51.1.1)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.6/site-packages (from pooch>=1.0->librosa) (20.8)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.6/site-packages (from pooch>=1.0->librosa) (2.25.1)\n",
      "Collecting appdirs\n",
      "  Using cached https://mirrors.aliyun.com/pypi/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\n",
      "Requirement already satisfied: six>=1.3 in /opt/conda/lib/python3.6/site-packages (from resampy>=0.2.2->librosa) (1.15.0)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.6/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)\n",
      "Requirement already satisfied: cffi>=1.0 in /opt/conda/lib/python3.6/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)\n",
      "Requirement already satisfied: pycparser in /opt/conda/lib/python3.6/site-packages (from cffi>=1.0->soundfile>=0.9.0->librosa) (2.20)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.6/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (2.10)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (1.26.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)\n",
      "Installing collected packages: appdirs, soundfile, resampy, pooch, audioread, librosa\n",
      "Successfully installed appdirs-1.4.4 audioread-2.1.9 librosa-0.8.0 pooch-1.3.0 resampy-0.2.2 soundfile-0.10.3.post1\n"
     ]
    }
   ],
   "source": [
    "!pip install librosa --user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 其他库\n",
    "\n",
    "import os\n",
    "import librosa\n",
    "import librosa.display\n",
    "import glob "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 特征提取以及数据集的建立"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature = []\n",
    "label = []\n",
    "# 建立类别标签，不同类别对应不同的数字。\n",
    "label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2,'candied_fruits':3, 'carrots': 4, 'chips':5,\n",
    "                  'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream':11,\n",
    "                  'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon':17,\n",
    "                  'soup': 18, 'wings': 19}\n",
    "label_dict_inv = {v:k for k,v in label_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "def extract_features(parent_dir, sub_dirs, max_file=10, file_ext=\"*.wav\"):\n",
    "    c = 0\n",
    "    label, feature = [], []\n",
    "    for sub_dir in sub_dirs:\n",
    "        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]): # 遍历数据集的所有文件\n",
    "            \n",
    "           # segment_log_specgrams, segment_labels = [], []\n",
    "            #sound_clip,sr = librosa.load(fn)\n",
    "            #print(fn)\n",
    "            label_name = fn.split('/')[-2]\n",
    "            label.extend([label_dict[label_name]])\n",
    "            X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
    "            mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
    "            feature.extend([mels])\n",
    "            \n",
    "    return [feature, label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 45/45 [00:08<00:00,  4.85it/s]\n",
      "100%|██████████| 64/64 [00:12<00:00,  5.10it/s]\n",
      "100%|██████████| 48/48 [00:12<00:00,  2.48it/s]\n",
      "100%|██████████| 74/74 [00:21<00:00,  1.40it/s]\n",
      "100%|██████████| 49/49 [00:11<00:00,  3.69it/s]\n",
      "100%|██████████| 57/57 [00:13<00:00,  4.00it/s]\n",
      "100%|██████████| 27/27 [00:06<00:00,  4.55it/s]\n",
      "100%|██████████| 27/27 [00:05<00:00,  4.48it/s]\n",
      "100%|██████████| 57/57 [00:11<00:00,  4.95it/s]\n",
      "100%|██████████| 61/61 [00:13<00:00,  4.06it/s]\n",
      "100%|██████████| 65/65 [00:15<00:00,  4.30it/s]\n",
      "100%|██████████| 69/69 [00:17<00:00,  3.40it/s]\n",
      "100%|██████████| 43/43 [00:10<00:00,  4.39it/s]\n",
      "100%|██████████| 33/33 [00:07<00:00,  4.49it/s]\n",
      "100%|██████████| 75/75 [00:18<00:00,  3.57it/s]\n",
      "100%|██████████| 55/55 [00:14<00:00,  2.95it/s]\n",
      "100%|██████████| 47/47 [00:12<00:00,  3.46it/s]\n",
      "100%|██████████| 37/37 [00:10<00:00,  3.80it/s]\n",
      "100%|██████████| 32/32 [00:06<00:00,  4.69it/s]\n",
      "100%|██████████| 35/35 [00:08<00:00,  4.08it/s]\n"
     ]
    }
   ],
   "source": [
    "# 自己更改目录\n",
    "parent_dir = './train_sample/'\n",
    "save_dir = \"./\"\n",
    "folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits',\n",
    "                             'carrots','chips','chocolate','drinks','fries',\n",
    "                            'grapes','gummies','ice-cream','jelly','noodles','pickles',\n",
    "                            'pizza','ribs','salmon','soup','wings'])\n",
    "\n",
    "# 获取特征feature以及类别的label\n",
    "temp = extract_features(parent_dir,sub_dirs,max_file=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    }
   ],
   "source": [
    "temp = np.array(temp)\n",
    "data = temp.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X的特征尺寸是： (1000, 128)\n",
      "Y的特征尺寸是： (1000,)\n"
     ]
    }
   ],
   "source": [
    "# 获取特征\n",
    "X = np.vstack(data[:, 0])\n",
    "\n",
    "# 获取标签\n",
    "Y = np.array(data[:, 1])\n",
    "print('X的特征尺寸是：',X.shape)\n",
    "print('Y的特征尺寸是：',Y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在Keras库中：to_categorical就是将类别向量转换为二进制（只有0和1）的矩阵类型表示\n",
    "Y = to_categorical(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 128)\n",
      "(1000, 20)\n"
     ]
    }
   ],
   "source": [
    "'''最终数据'''\n",
    "print(X.shape)\n",
    "print(Y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集的大小 750\n",
      "测试集的大小 250\n"
     ]
    }
   ],
   "source": [
    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)\n",
    "print('训练集的大小',len(X_train))\n",
    "print('测试集的大小',len(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = X_train.reshape(-1, 16, 8, 1)\n",
    "X_test = X_test.reshape(-1, 16, 8, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 搭建CNN网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "\n",
    "# 输入的大小\n",
    "input_dim = (16, 8, 1)\n",
    "\n",
    "model.add(Conv2D(64, (3, 3), padding = \"same\", activation = \"tanh\", input_shape = input_dim))# 卷积层\n",
    "model.add(MaxPool2D(pool_size=(2, 2)))# 最大池化\n",
    "model.add(Conv2D(128, (3, 3), padding = \"same\", activation = \"tanh\")) #卷积层\n",
    "model.add(MaxPool2D(pool_size=(2, 2))) # 最大池化层\n",
    "model.add(Dropout(0.1))\n",
    "model.add(Flatten()) # 展开\n",
    "model.add(Dense(1024, activation = \"tanh\"))\n",
    "model.add(Dense(20, activation = \"softmax\")) # 输出层：20个units输出20个类的概率\n",
    "\n",
    "# 编译模型，设置损失函数，优化方法以及评价标准\n",
    "model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_4 (Conv2D)            (None, 16, 8, 64)         640       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_4 (MaxPooling2 (None, 8, 4, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_5 (Conv2D)            (None, 8, 4, 128)         73856     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_5 (MaxPooling2 (None, 4, 2, 128)         0         \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 4, 2, 128)         0         \n",
      "_________________________________________________________________\n",
      "flatten_2 (Flatten)          (None, 1024)              0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 1024)              1049600   \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 20)                20500     \n",
      "=================================================================\n",
      "Total params: 1,144,596\n",
      "Trainable params: 1,144,596\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "50/50 [==============================] - 3s 49ms/step - loss: 1.5272 - accuracy: 0.5973 - val_loss: 3.6820 - val_accuracy: 0.2960\n",
      "Epoch 2/20\n",
      "50/50 [==============================] - 2s 47ms/step - loss: 1.4931 - accuracy: 0.6013 - val_loss: 3.5160 - val_accuracy: 0.3400\n",
      "Epoch 3/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.9785 - accuracy: 0.7027 - val_loss: 3.2466 - val_accuracy: 0.3400\n",
      "Epoch 4/20\n",
      "50/50 [==============================] - 2s 42ms/step - loss: 0.6756 - accuracy: 0.7880 - val_loss: 3.0725 - val_accuracy: 0.4080\n",
      "Epoch 5/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.5808 - accuracy: 0.8333 - val_loss: 3.3328 - val_accuracy: 0.3880\n",
      "Epoch 6/20\n",
      "50/50 [==============================] - 2s 40ms/step - loss: 0.5391 - accuracy: 0.8507 - val_loss: 3.4656 - val_accuracy: 0.3640\n",
      "Epoch 7/20\n",
      "50/50 [==============================] - 2s 41ms/step - loss: 0.4948 - accuracy: 0.8520 - val_loss: 3.6385 - val_accuracy: 0.3680\n",
      "Epoch 8/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.4554 - accuracy: 0.8747 - val_loss: 3.8555 - val_accuracy: 0.3720\n",
      "Epoch 9/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.3798 - accuracy: 0.8960 - val_loss: 3.8773 - val_accuracy: 0.3880\n",
      "Epoch 10/20\n",
      "50/50 [==============================] - 2s 44ms/step - loss: 0.3407 - accuracy: 0.9040 - val_loss: 3.5860 - val_accuracy: 0.3680\n",
      "Epoch 11/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.3318 - accuracy: 0.9067 - val_loss: 4.1420 - val_accuracy: 0.3880\n",
      "Epoch 12/20\n",
      "50/50 [==============================] - 2s 41ms/step - loss: 0.3400 - accuracy: 0.8933 - val_loss: 4.0300 - val_accuracy: 0.4000\n",
      "Epoch 13/20\n",
      "50/50 [==============================] - 2s 42ms/step - loss: 0.3098 - accuracy: 0.9093 - val_loss: 4.0889 - val_accuracy: 0.3800\n",
      "Epoch 14/20\n",
      "50/50 [==============================] - 2s 47ms/step - loss: 0.3750 - accuracy: 0.8867 - val_loss: 4.2326 - val_accuracy: 0.4200\n",
      "Epoch 15/20\n",
      "50/50 [==============================] - 2s 43ms/step - loss: 0.3795 - accuracy: 0.9053 - val_loss: 4.5108 - val_accuracy: 0.3640\n",
      "Epoch 16/20\n",
      "50/50 [==============================] - 2s 45ms/step - loss: 0.3753 - accuracy: 0.8960 - val_loss: 4.2285 - val_accuracy: 0.4120\n",
      "Epoch 17/20\n",
      "50/50 [==============================] - 2s 41ms/step - loss: 0.3066 - accuracy: 0.9027 - val_loss: 4.4688 - val_accuracy: 0.3880\n",
      "Epoch 18/20\n",
      "50/50 [==============================] - 2s 45ms/step - loss: 0.3143 - accuracy: 0.9027 - val_loss: 4.8396 - val_accuracy: 0.3880\n",
      "Epoch 19/20\n",
      "50/50 [==============================] - 2s 42ms/step - loss: 0.2654 - accuracy: 0.9173 - val_loss: 4.2883 - val_accuracy: 0.4000\n",
      "Epoch 20/20\n",
      "50/50 [==============================] - 2s 44ms/step - loss: 0.2140 - accuracy: 0.9400 - val_loss: 4.3122 - val_accuracy: 0.4160\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fee3082a860>"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 训练模型\n",
    "model.fit(X_train, Y_train, epochs = 20, batch_size = 15, validation_data = (X_test, Y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 预测测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_features(test_dir, file_ext=\"*.wav\"):\n",
    "    feature = []\n",
    "    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))[:]): # 遍历数据集的所有文件\n",
    "        X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
    "        mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
    "        feature.extend([mels])\n",
    "    return feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [09:25<00:00,  2.90it/s]\n"
     ]
    }
   ],
   "source": [
    "X_test = extract_features('./test_a/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = np.vstack(X_test)\n",
    "predictions = model.predict(X_test.reshape(-1, 16, 8, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = np.argmax(predictions, axis = 1)\n",
    "preds = [label_dict_inv[x] for x in preds]\n",
    "\n",
    "path = glob.glob('./test_a/*.wav')\n",
    "result = pd.DataFrame({'name':path, 'label': preds})\n",
    "\n",
    "result['name'] = result['name'].apply(lambda x: x.split('/')[-1])\n",
    "result.to_csv('submit.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000\n"
     ]
    }
   ],
   "source": [
    "!ls ./test_a/*.wav | wc -l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2001 submit.csv\n"
     ]
    }
   ],
   "source": [
    "!wc -l submit.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "tianchi_metadata": {
   "competitions": [],
   "datasets": [
    {
     "id": "96728",
     "title": "获取数据集标题失败"
    }
   ],
   "description": "",
   "notebookId": "185525",
   "source": "dsw"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
